import torch
from tqdm import trange

from comfy.samplers import KSAMPLER


def generate_trend_values(steps, start_time, end_time, eta, eta_trend):
    """Build a per-step list of guidance strengths.

    Steps outside the half-open window ``[start_time, end_time)`` keep the
    value 0. Inside the window the value depends on ``eta_trend``:

    * ``'constant'``        -- ``eta`` at every step,
    * ``'linear_increase'`` -- ramps from 0 at ``start_time`` up to ``eta``
      at ``end_time - 1``,
    * ``'linear_decrease'`` -- ramps from ``eta`` at ``start_time`` down to
      0 at ``end_time - 1``.

    Any other ``eta_trend`` leaves the whole schedule at 0.

    Args:
        steps: Length of the returned list.
        start_time: First step index of the active window (inclusive).
        end_time: End step index of the active window (exclusive).
        eta: Peak strength.
        eta_trend: One of the trend names listed above.

    Returns:
        A list of ``steps`` numbers.
    """
    eta_values = [0] * steps

    # Only write indices that exist: a window reaching past the schedule
    # (e.g. end_time larger than the number of sampling steps) is truncated
    # instead of raising IndexError, and a negative start does not wrap
    # around to the end of the list.
    first = max(start_time, 0)
    last = min(end_time, steps)

    # The ramp is still measured against the requested window, so truncation
    # does not change its slope. A one-step window would make the denominator
    # zero; treat it as the first point of the ramp instead.
    span = max(end_time - start_time - 1, 1)

    if eta_trend == 'constant':
        for i in range(first, last):
            eta_values[i] = eta
    elif eta_trend == 'linear_increase':
        for i in range(first, last):
            progress = (i - start_time) / span
            eta_values[i] = eta * progress
    elif eta_trend == 'linear_decrease':
        for i in range(first, last):
            progress = 1 - (i - start_time) / span
            eta_values[i] = eta * progress

    return eta_values



def get_sample_forward(gamma, start_step, end_step, gamma_trend, seed, attn_bank=None, order="first"):
    """Create the sampler function for the controlled forward ODE (Algorithm 1).

    Args:
        gamma: Peak blend weight between the model prediction and the
            conditional vector field pointing at the noise sample.
        start_step: First step (inclusive) at which ``gamma`` is applied.
        end_step: End step (exclusive) at which ``gamma`` is applied.
        gamma_trend: Trend name passed to ``generate_trend_values``.
        seed: Seed for the noise sample ``y1`` drawn on every sampler call.
        attn_bank: Optional dict with a ``'block_map'`` entry; every stored
            value is cleared at the start of each run and the dict is handed
            to the model through ``transformer_options['attn_bank']``.
        order: ``'second'`` adds a midpoint model evaluation and a
            second-order update term; any other value uses a plain Euler step.

    Returns:
        A function ``sample_forward(model, y0, sigmas, extra_args=None,
        callback=None, disable=None)`` that returns the integrated latent.
    """

    @torch.no_grad()
    def sample_forward(model, y0, sigmas, extra_args=None, callback=None, disable=None):
        if attn_bank is not None:
            for block_idx in attn_bank['block_map']:
                attn_bank['block_map'][block_idx].clear()

        # Work on shallow copies so the caller's extra_args / model_options /
        # transformer_options dicts are never modified by this run.
        extra_args = {} if extra_args is None else {**extra_args}
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {
            **model_options.get('transformer_options', {}),
            'total_steps': len(sigmas) - 1,
            'sample_mode': 'forward',
            'attn_bank': attn_bank,
        }
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        # Seed a fresh generator on every call: the same seed then always
        # yields the same y1, no matter how often this sampler object has
        # already been used.
        generator = torch.Generator()
        generator.manual_seed(seed)

        Y = y0.clone()
        y1 = torch.randn(Y.shape, generator=generator).to(y0.device)
        N = len(sigmas) - 1
        s_in = y0.new_ones([y0.shape[0]])
        gamma_values = generate_trend_values(N, start_step, end_step, gamma, gamma_trend)
        model_sampling = model.inner_model.inner_model.model_sampling
        for i in trange(N, disable=disable):
            # transformer_options is the dict referenced from extra_args, so
            # the model sees the per-step values written here.
            transformer_options['step'] = i
            sigma = sigmas[i]
            sigma_next = sigmas[i + 1]
            t_i = model_sampling.timestep(sigma)

            # Field pointing from the current state towards the noise sample.
            conditional_vector_field = (y1 - Y) / (1 - t_i)

            transformer_options['pred_order'] = 'first'
            pred = model(Y, s_in * sigma, **extra_args)  # this implementation takes sigma instead of timestep

            if order == 'second':
                transformer_options['pred_order'] = 'second'
                half_step = (sigma_next - sigma) / 2
                img_mid = Y + half_step * pred
                sigma_mid = sigma + half_step
                pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)

                # Finite-difference estimate of d(pred)/d(sigma) over the half step.
                first_order = (pred_mid - pred) / half_step
                # NOTE: only `pred` is blended with the conditional field; the
                # derivative estimate is used as-is (a blended variant was
                # left disabled in the original code).
                pred = pred + gamma_values[i] * (conditional_vector_field - pred)
                Y = Y + (sigma_next - sigma) * pred + 0.5 * (sigma_next - sigma) ** 2 * first_order
            else:
                pred = pred + gamma_values[i] * (conditional_vector_field - pred)
                Y = Y + pred * (sigma_next - sigma)

            if callback is not None:
                callback({'x': Y, 'denoised': Y, 'i': i, 'sigma': sigma, 'sigma_hat': sigma})

        return Y

    return sample_forward


def get_sample_reverse(latent_image, eta, start_time, end_time, eta_trend, attn_bank=None, order='first'):
    """Create the sampler function for the controlled reverse ODE (Algorithm 2).

    Args:
        latent_image: Reference latent ``y0`` that the conditional vector
            field points towards; cloned and moved to the input's device on
            every sampler call.
        eta: Peak blend weight between the negated model prediction and the
            conditional vector field.
        start_time: First step (inclusive) at which ``eta`` is applied.
        end_time: End step (exclusive) at which ``eta`` is applied.
        eta_trend: Trend name passed to ``generate_trend_values``.
        attn_bank: Optional object handed to the model through
            ``transformer_options['attn_bank']``.
        order: ``'second'`` adds a midpoint model evaluation and a
            second-order update term; any other value uses a plain Euler step.

    Returns:
        A function ``sample_reverse(model, y1, sigmas, extra_args=None,
        callback=None, disable=None)`` that returns the integrated latent.
    """

    @torch.no_grad()
    def sample_reverse(model, y1, sigmas, extra_args=None, callback=None, disable=None):
        # Work on shallow copies so the caller's extra_args / model_options /
        # transformer_options dicts are never modified by this run.
        extra_args = {} if extra_args is None else {**extra_args}
        model_options = {**extra_args.get('model_options', {})}
        transformer_options = {
            **model_options.get('transformer_options', {}),
            'total_steps': len(sigmas) - 1,
            'sample_mode': 'reverse',
            'attn_bank': attn_bank,
        }
        model_options['transformer_options'] = transformer_options
        extra_args['model_options'] = model_options

        X = y1.clone()
        N = len(sigmas) - 1
        y0 = latent_image.clone().to(y1.device)
        s_in = y0.new_ones([y0.shape[0]])
        eta_values = generate_trend_values(N, start_time, end_time, eta, eta_trend)
        model_sampling = model.inner_model.inner_model.model_sampling
        for i in trange(N, disable=disable):
            # transformer_options is the dict referenced from extra_args, so
            # the model sees the per-step values written here.
            transformer_options['step'] = i
            sigma = sigmas[i]
            sigma_prev = sigmas[i + 1]
            # Time is measured in the opposite direction to the forward pass.
            t_i = 1 - model_sampling.timestep(sigma)

            # Field pointing from the current state towards the reference latent.
            conditional_vector_field = (y0 - X) / (1 - t_i)

            transformer_options['pred_order'] = 'first'
            pred = model(X, sigma * s_in, **extra_args)  # this implementation takes sigma instead of timestep

            if order == 'second':
                transformer_options['pred_order'] = 'second'
                half_step = (sigma_prev - sigma) / 2
                img_mid = X + half_step * pred
                sigma_mid = sigma + half_step
                pred_mid = model(img_mid, s_in * sigma_mid, **extra_args)

                # Finite-difference estimate of d(pred)/d(sigma) over the half step.
                first_order = (pred_mid - pred) / half_step
                # Both the velocity and its derivative estimate are negated
                # and blended with the conditional field.
                pred = -pred + eta_values[i] * (conditional_vector_field + pred)
                first_order = -first_order + eta_values[i] * (conditional_vector_field + first_order)
                X = X + (sigma - sigma_prev) * pred + 0.5 * (sigma - sigma_prev) ** 2 * first_order
            else:
                controlled_vector_field = -pred + eta_values[i] * (conditional_vector_field + pred)
                X = X + controlled_vector_field * (sigma - sigma_prev)

            if callback is not None:
                callback({'x': X, 'denoised': X, 'i': i, 'sigma': sigma, 'sigma_hat': sigma})

        return X

    return sample_reverse


class LTXRFForwardODESamplerNode:
    """Node that wraps the controlled forward ODE sampler in a ``KSAMPLER``."""

    CATEGORY = "ltxtricks"
    FUNCTION = "build"
    RETURN_TYPES = ("SAMPLER",)

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs (required first, then optional)."""

        def step_widget(default):
            # Both step inputs share the same bounds and differ only in default.
            return ("INT", {"default": default, "min": 0, "max": 1000, "step": 1})

        required = {
            "gamma": ("FLOAT", {"default": 0.5, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": step_widget(0),
            "end_step": step_widget(5),
            "gamma_trend": (['linear_decrease', 'linear_increase', 'constant'],),
        }
        optional = {
            "seed": ("INT", {"default": 0, "min": 0, "max": 0xffffffffffffffff}),
            "attn_bank": ("ATTN_BANK",),
            "order": (["first", "second"],),
        }
        return {"required": required, "optional": optional}

    def build(self, gamma, start_step, end_step, gamma_trend, seed=0, attn_bank=None, order="first"):
        """Return a one-element tuple holding the configured sampler."""
        sample_fn = get_sample_forward(
            gamma,
            start_step,
            end_step,
            gamma_trend,
            seed,
            attn_bank=attn_bank,
            order=order,
        )
        return (KSAMPLER(sample_fn),)


class LTXRFReverseODESamplerNode:
    """Node that wraps the controlled reverse ODE sampler in a ``KSAMPLER``."""

    CATEGORY = "ltxtricks"
    FUNCTION = "build"
    RETURN_TYPES = ("SAMPLER",)

    @classmethod
    def INPUT_TYPES(cls):
        """Describe the node's inputs (required first, then optional)."""

        def step_widget(default):
            # Both step inputs share the same bounds and differ only in default.
            return ("INT", {"default": default, "min": 0, "max": 1000, "step": 1})

        required = {
            "model": ("MODEL",),
            "latent_image": ("LATENT",),
            "eta": ("FLOAT", {"default": 0.8, "min": 0.0, "max": 100.0, "step": 0.01}),
            "start_step": step_widget(0),
            "end_step": step_widget(15),
        }
        optional = {
            "eta_trend": (['linear_decrease', 'linear_increase', 'constant'],),
            "attn_inj": ("ATTN_INJ",),
            "order": (["first", "second"],),
        }
        return {"required": required, "optional": optional}

    def build(self, model, latent_image, eta, start_step, end_step, eta_trend='constant', attn_inj=None, order='first'):
        """Return a one-element tuple holding the configured sampler.

        The ``'samples'`` entry of ``latent_image`` is first passed through
        the model's ``process_latent_in`` object and the result is handed to
        the reverse sampler as its reference latent.
        """
        reference_latent = model.get_model_object("process_latent_in")(latent_image['samples'])
        sample_fn = get_sample_reverse(
            reference_latent,
            eta,
            start_step,
            end_step,
            eta_trend,
            attn_bank=attn_inj,
            order=order,
        )
        return (KSAMPLER(sample_fn),)
